// Copyright 2018 The Cockroach Authors.
// Copyright (c) 2022-present, Shanghai Yunxi Technology Co, Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This software (KWDB) is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of the Mulan PSL v2.
// You may obtain a copy of Mulan PSL v2 at:
//          http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
// EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
// MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

package tree_test

import (
	"bytes"
	"context"
	"flag"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"runtime"
	"strings"
	"testing"

	"gitee.com/kwbasedb/kwbase/pkg/sql/parser"
	_ "gitee.com/kwbasedb/kwbase/pkg/sql/sem/builtins"
	"gitee.com/kwbasedb/kwbase/pkg/sql/sem/tree"
	"gitee.com/kwbasedb/kwbase/pkg/sql/types"
	"gitee.com/kwbasedb/kwbase/pkg/testutils/sqlutils"
	"gitee.com/kwbasedb/kwbase/pkg/util/leaktest"
	"gitee.com/kwbasedb/kwbase/pkg/util/pretty"
	"golang.org/x/sync/errgroup"
)

var (
	flagWritePretty = flag.Bool("rewrite-pretty", false, "rewrite pretty test outputs")
	testPrettyCfg   = func() tree.PrettyCfg {
		cfg := tree.DefaultPrettyCfg()
		cfg.JSONFmt = true
		return cfg
	}()
)

// TestPrettyData reads in a single SQL statement from a file, formats
// it at 40 characters width, and compares that output to a known-good
// output file. It is most useful when changing or implementing the
// doc interface for a node, and should be used to compare and verify
// the changed output.
func TestPrettyDataShort(t *testing.T) {
	defer leaktest.AfterTest(t)()
	matches, err := filepath.Glob(filepath.Join("testdata", "pretty", "*.sql"))
	if err != nil {
		t.Fatal(err)
	}
	if *flagWritePretty {
		t.Log("WARNING: do not forget to run TestPrettyData with build flag 'nightly' and the -rewrite-pretty flag too!")
	}
	cfg := testPrettyCfg
	cfg.Align = tree.PrettyNoAlign
	t.Run("ref", func(t *testing.T) {
		runTestPrettyData(t, "ref", cfg, matches, true /*short*/)
	})
	cfg.Align = tree.PrettyAlignAndDeindent
	t.Run("align-deindent", func(t *testing.T) {
		runTestPrettyData(t, "align-deindent", cfg, matches, true /*short*/)
	})
	cfg.Align = tree.PrettyAlignOnly
	t.Run("align-only", func(t *testing.T) {
		runTestPrettyData(t, "align-only", cfg, matches, true /*short*/)
	})
}

func runTestPrettyData(
	t *testing.T, prefix string, cfg tree.PrettyCfg, matches []string, short bool,
) {
	for _, m := range matches {
		m := m
		t.Run(filepath.Base(m), func(t *testing.T) {
			sql, err := ioutil.ReadFile(m)
			if err != nil {
				t.Fatal(err)
			}
			stmt, err := parser.ParseOne(string(sql))
			if err != nil {
				t.Fatal(err)
			}

			// We have a statement, now we need to format it at all possible line
			// lengths. We use the length of the string + 10 as the upper bound to try to
			// find what happens at the longest line length. Preallocate a result slice and
			// work chan, then fire off a bunch of workers to compute all of the variants.
			var res []string
			if short {
				res = []string{""}
			} else {
				res = make([]string, len(sql)+10)
			}
			type param struct{ idx, numCols int }
			work := make(chan param, len(res))
			if short {
				work <- param{0, 40}
			} else {
				for i := range res {
					work <- param{i, i + 1}
				}
			}
			close(work)
			g, _ := errgroup.WithContext(context.Background())
			worker := func() error {
				for p := range work {
					thisCfg := cfg
					thisCfg.LineWidth = p.numCols
					res[p.idx] = thisCfg.Pretty(stmt.AST)
				}
				return nil
			}
			for i := 0; i < runtime.NumCPU(); i++ {
				g.Go(worker)
			}
			if err := g.Wait(); err != nil {
				t.Fatal(err)
			}
			var sb strings.Builder
			for i, s := range res {
				// Only write each new result to the output, along with a small header
				// indicating the line length.
				if i == 0 || s != res[i-1] {
					fmt.Fprintf(&sb, "%d:\n%s\n%s\n\n", i+1, strings.Repeat("-", i+1), s)
				}
			}
			var gotB bytes.Buffer
			gotB.WriteString("// Code generated by TestPretty. DO NOT EDIT.\n")
			gotB.WriteString("// GENERATED FILE DO NOT EDIT\n")
			gotB.WriteString(sb.String())
			gotB.WriteByte('\n')
			got := gotB.String()

			ext := filepath.Ext(m)
			outfile := m[:len(m)-len(ext)] + "." + prefix + ".golden"
			if short {
				outfile = outfile + ".short"
			}

			if *flagWritePretty {
				if err := ioutil.WriteFile(outfile, []byte(got), 0666); err != nil {
					t.Fatal(err)
				}
				return
			}

			expect, err := ioutil.ReadFile(outfile)
			if err != nil {
				t.Fatal(err)
			}
			if string(expect) != got {
				t.Fatalf("expected:\n%s\ngot:\n%s", expect, got)
			}

			sqlutils.VerifyStatementPrettyRoundtrip(t, string(sql))
		})
	}
}

func TestPrettyVerify(t *testing.T) {
	defer leaktest.AfterTest(t)()
	tests := map[string]string{
		// Verify that INTERVAL is maintained.
		`SELECT interval '-2µs'`: `SELECT '-00:00:00.000002':::INTERVAL`,
	}
	for orig, pretty := range tests {
		t.Run(orig, func(t *testing.T) {
			sqlutils.VerifyStatementPrettyRoundtrip(t, orig)

			stmt, err := parser.ParseOne(orig)
			if err != nil {
				t.Fatal(err)
			}
			got := tree.Pretty(stmt.AST)
			if pretty != got {
				t.Fatalf("got: %s\nexpected: %s", got, pretty)
			}
		})
	}
}

func BenchmarkPrettyData(b *testing.B) {
	matches, err := filepath.Glob(filepath.Join("testdata", "pretty", "*.sql"))
	if err != nil {
		b.Fatal(err)
	}
	var docs []pretty.Doc
	cfg := tree.DefaultPrettyCfg()
	for _, m := range matches {
		sql, err := ioutil.ReadFile(m)
		if err != nil {
			b.Fatal(err)
		}
		stmt, err := parser.ParseOne(string(sql))
		if err != nil {
			b.Fatal(err)
		}
		docs = append(docs, cfg.Doc(stmt.AST))
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, doc := range docs {
			for _, w := range []int{1, 30, 80} {
				pretty.Pretty(doc, w, true /*useTabs*/, 4 /*tabWidth*/, nil /* keywordTransform */)
			}
		}
	}
}

func TestPrettyExprs(t *testing.T) {
	defer leaktest.AfterTest(t)()
	tests := map[tree.Expr]string{
		&tree.CastExpr{
			Expr: tree.NewDString("foo"),
			Type: types.MakeCollatedString(types.String, "en"),
		}: `CAST('foo':::STRING AS STRING) COLLATE en`,
	}

	for expr, pretty := range tests {
		got := tree.Pretty(expr)
		if pretty != got {
			t.Fatalf("got: %s\nexpected: %s", got, pretty)
		}
	}
}
